import torch
import numpy as np

# Scratch experiments with tensor creation, reshaping and argmax.
# NOTE(review): in the visible code `np` (imported above) and `b` are never
# used, and the random tensors come from an unseeded RNG, so the printed
# result differs from run to run.

# 5-D tensor of standard-normal samples, shape (128, 1, 28, 28, 1).
a = torch.randn(128, 1, 28, 28, 1)

# Small 1-D integer tensor; dtype int64 is inferred from the Python ints.
b = torch.tensor((0, 2))

# Flatten `a` into rows of 28 * 28 = 784 values; -1 lets torch infer the
# number of rows (128 for the shape above).
c = a.reshape(-1, 28 * 28)

# 3x4 tensor of standard-normal samples.
data = torch.randn(3, 4)
# For each of the 4 columns, print the row index (0..2) of its largest value.
print(torch.argmax(data, dim=0))

